{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# simplified Confident Learning Tutorial\n",
    "*Author: Curtis G. Northcutt, cgn@mit.edu*\n",
    "\n",
    "In this tutorial, we show how to implement confident learning without using cleanlab (for the most part). \n",
    "This tutorial is to confident learning what this tutorial https://pytorch.org/tutorials/beginner/examples_tensor/two_layer_net_numpy.html\n",
    "is to deep learning.\n",
    "\n",
    "The actual implementations in cleanlab are complex because they support parallel processing, numerous type and input checks, lots of hyper-parameter settings, lots of utilities to make things work smoothly for all types of inputs, and ancillary functions.\n",
    "\n",
    "I ignore all of that here and provide you a bare-bones implementation using mostly for-loops and some numpy.\n",
    "Here we'll do two simple things:\n",
    "1. Compute the confident joint which fully characterizes all label noise.\n",
    "2. Find the indices of all label errors, ordered by likelihood of being an error.\n",
    "\n",
    "## INPUT (stuff we need beforehand):\n",
    "1. s - These are the noisy labels. This is an np.array of noisy labels, shape (n,1)\n",
    "2. psx - These are the out-of-sample holdout predicted probabilities for every example in your dataset. This is an np.array (2d) of probabilities, shape (n, m)\n",
    "\n",
    "## OUTPUT (what this returns):\n",
    "1. confident_joint - an (m, m) np.array matrix characterizing all the label error counts for every pair of labels.\n",
    "2. label_errors_idx - a numpy array comprised of indices of every label error, ordered by likelihood of being a label error.\n",
    "\n",
    "In this tutorial we use the handwritten digits dataset as an example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import print_function, absolute_import, division, with_statement\n",
    "import cleanlab\n",
    "import numpy as np\n",
    "from sklearn.datasets import load_digits\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "# To silence convergence warnings caused by using a weak\n",
    "# logistic regression classifier on image data\n",
    "import warnings\n",
    "warnings.simplefilter(\"ignore\")\n",
    "np.random.seed(477)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Handwritten digits datasets number of classes: 10\n",
      "Handwritten digits datasets number of examples: 1797\n",
      "\n",
      "Indices of actual label errors:\n",
      " [  24   27   42   47   50   54   58   62   85   94  105  107  110  157\n",
      "  159  182  201  211  220  223  232  300  310  312  336  343  348  350\n",
      "  358  366  375  387  400  418  430  462  469  483  502  538  555  559\n",
      "  562  564  614  623  626  647  649  674  694  706  708  761  778  796\n",
      "  807  830  834  840  883  887  980  996 1007 1044 1064 1083 1084 1104\n",
      " 1105 1111 1122 1176 1192 1257 1261 1265 1301 1312 1324 1327 1341 1375\n",
      " 1472 1483 1492 1493 1496 1524 1534 1584 1592 1643 1646 1648 1675 1707\n",
      " 1744 1779]\n"
     ]
    }
   ],
   "source": [
    "# STEP 0 - Get some real digits data. Add a bunch of label errors. Get probs.\n",
    "\n",
    "# Get handwritten digits data\n",
    "X = load_digits()['data']\n",
    "y = load_digits()['target']\n",
    "print('Handwritten digits datasets number of classes:', len(np.unique(y)))\n",
    "print('Handwritten digits datasets number of examples:', len(y))\n",
    "\n",
    "# Add lots of errors to labels\n",
    "NUM_ERRORS = 100\n",
    "s = np.array(y)\n",
    "error_indices = np.random.choice(len(s), NUM_ERRORS, replace=False)\n",
    "for i in error_indices:\n",
    "    # Switch to some wrong label thats a different class\n",
    "    wrong_label = np.random.choice(np.delete(range(10), s[i]))\n",
    "    s[i] = wrong_label\n",
    "\n",
    "# Confirm that we indeed added NUM_ERRORS label errors\n",
    "assert (len(s) - sum(s == y) == NUM_ERRORS)\n",
    "actual_label_errors = np.arange(len(y))[s != y]\n",
    "print('\\nIndices of actual label errors:\\n', actual_label_errors)\n",
    "\n",
    "# To keep the tutorial short, we use cleanlab to get the \n",
    "# out-of-sample predicted probabilities using cross-validation\n",
    "# with a very simple, non-optimized logistic regression classifier\n",
    "psx = cleanlab.latent_estimation.estimate_cv_predicted_probabilities(\n",
    "    X, s, clf=LogisticRegression(max_iter=1000, multi_class='auto', solver='lbfgs'))\n",
    "\n",
    "# Now we have our noisy labels s and predicted probabilities psx.\n",
    "# That's all we need for confident learning."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      " Joint Label Noise Distribution Matrix P(s,y) of shape (10, 10)\n",
      " p(s,y)\ty=0\ty=1\ty=2\ty=3\ty=4\ty=5\ty=6\ty=7\ty=8\ty=9\n",
      "\t---\t---\t---\t---\t---\t---\t---\t---\t---\t---\n",
      "s=0 |\t170\t2\t2\t0\t1\t4\t0\t0\t2\t0\n",
      "s=1 |\t2\t167\t3\t1\t3\t1\t0\t0\t3\t0\n",
      "s=2 |\t0\t3\t157\t4\t0\t1\t0\t1\t0\t1\n",
      "s=3 |\t0\t0\t3\t176\t0\t0\t1\t1\t3\t1\n",
      "s=4 |\t1\t1\t0\t0\t165\t1\t0\t1\t2\t3\n",
      "s=5 |\t0\t2\t5\t1\t4\t165\t2\t0\t1\t6\n",
      "s=6 |\t0\t4\t1\t1\t2\t0\t167\t3\t1\t5\n",
      "s=7 |\t0\t1\t0\t0\t4\t3\t3\t163\t0\t1\n",
      "s=8 |\t0\t3\t4\t1\t4\t1\t3\t4\t164\t2\n",
      "s=9 |\t1\t1\t1\t3\t5\t1\t2\t5\t4\t156\n",
      "\tTrace(matrix) = 1650\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# STEP 1 - Compute confident joint\n",
    "\n",
    "# Verify inputs\n",
    "s = np.asarray(s)\n",
    "psx = np.asarray(psx)\n",
    "\n",
    "# Find the number of unique classes if K is not given\n",
    "K = len(np.unique(s))\n",
    "\n",
    "# Estimate the probability thresholds for confident counting\n",
    "# You can specify these thresholds yourself if you want\n",
    "# as you may want to optimize them using a validation set.\n",
    "# By default (and provably so) they are set to the average class prob.\n",
    "thresholds = [np.mean(psx[:,k][s == k]) for k in range(K)] # P(s^=k|s=k)\n",
    "thresholds = np.asarray(thresholds)\n",
    "\n",
    "# Compute confident joint\n",
    "confident_joint = np.zeros((K, K), dtype = int)\n",
    "for i, row in enumerate(psx):\n",
    "    s_label = s[i]\n",
    "    # Find out how many classes each example is confidently labeled as\n",
    "    confident_bins = row >= thresholds - 1e-6\n",
    "    num_confident_bins = sum(confident_bins)\n",
    "    # If more than one conf class, inc the count of the max prob class\n",
    "    if num_confident_bins == 1:\n",
    "        confident_joint[s_label][np.argmax(confident_bins)] += 1\n",
    "    elif num_confident_bins > 1:\n",
    "        confident_joint[s_label][np.argmax(row)] += 1\n",
    "\n",
    "# Normalize confident joint (use cleanlab, trust me on this)\n",
    "confident_joint = cleanlab.latent_estimation.calibrate_confident_joint(\n",
    "    confident_joint, s)\n",
    "\n",
    "cleanlab.util.print_joint_matrix(confident_joint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Indices of label errors found by confident learning:\n",
      "Note label errors are sorted by likelihood of being an error\n",
      "but here we just sort them by index for comparison with above.\n",
      "[   5    9   24   42   43   46   47   50   54   58   62   69   77   85\n",
      "   94   95  105  107  110  123  157  159  182  191  201  211  220  223\n",
      "  227  232  267  300  310  312  325  336  343  348  350  358  361  366\n",
      "  375  384  387  418  421  430  462  483  492  502  521  538  547  555\n",
      "  559  562  614  623  626  647  649  674  694  706  708  745  757  761\n",
      "  769  770  778  794  795  807  830  834  840  883  887  920  980  996\n",
      " 1007 1022 1025 1038 1044 1064 1083 1084 1104 1105 1111 1118 1122 1146\n",
      " 1176 1192 1197 1233 1241 1257 1261 1265 1301 1312 1324 1327 1341 1375\n",
      " 1472 1483 1492 1493 1496 1524 1534 1553 1572 1575 1580 1584 1592 1597\n",
      " 1615 1618 1643 1645 1646 1648 1658 1660 1675 1707 1742 1744 1747 1779]\n"
     ]
    }
   ],
   "source": [
    "# STEP 2 - Find label errors\n",
    "\n",
    "# We arbitrarily choose at least 5 examples left in every class.\n",
    "# Regardless of whether some of them might be label errors.\n",
    "MIN_NUM_PER_CLASS = 5\n",
    "# Leave at least MIN_NUM_PER_CLASS examples per class.\n",
    "# NOTE prune_count_matrix is transposed (relative to confident_joint)\n",
    "prune_count_matrix = cleanlab.pruning.keep_at_least_n_per_class(\n",
    "    prune_count_matrix=confident_joint.T,\n",
    "    n=MIN_NUM_PER_CLASS,\n",
    ")\n",
    "\n",
    "s_counts = np.bincount(s)\n",
    "noise_masks_per_class = []\n",
    "# For each row in the transposed confident joint\n",
    "for k in range(K):\n",
    "    noise_mask = np.zeros(len(psx), dtype=bool)\n",
    "    psx_k = psx[:, k]\n",
    "    if s_counts[k] > MIN_NUM_PER_CLASS:  # Don't prune if not MIN_NUM_PER_CLASS\n",
    "        for j in range(K):  # noisy label index (k is the true label index)\n",
    "            if k != j:  # Only prune for noise rates, not diagonal entries\n",
    "                num2prune = prune_count_matrix[k][j]\n",
    "                if num2prune > 0:\n",
    "                    # num2prune'th largest p(classk) - p(class j)\n",
    "                    # for x with noisy label j\n",
    "                    margin = psx_k - psx[:, j]\n",
    "                    s_filter = s == j\n",
    "                    threshold = -np.partition(\n",
    "                        -margin[s_filter], num2prune - 1\n",
    "                    )[num2prune - 1]\n",
    "                    noise_mask = noise_mask | (s_filter & (margin >= threshold))\n",
    "        noise_masks_per_class.append(noise_mask)\n",
    "    else:\n",
    "        noise_masks_per_class.append(np.zeros(len(s), dtype=bool))\n",
    "\n",
    "# Boolean label error mask\n",
    "label_errors_bool = np.stack(noise_masks_per_class).any(axis=0)\n",
    "\n",
    " # Remove label errors if given label == model prediction\n",
    "for i, pred_label in enumerate(psx.argmax(axis=1)):\n",
    "    # np.all let's this work for multi_label and single label\n",
    "    if label_errors_bool[i] and np.all(pred_label == s[i]):\n",
    "        label_errors_bool[i] = False\n",
    "\n",
    "# Convert boolean mask to an ordered list of indices for label errors\n",
    "label_errors_idx = np.arange(len(s))[label_errors_bool]\n",
    "# self confidence is the holdout probability that an example\n",
    "# belongs to its given class label\n",
    "self_confidence = np.array(\n",
    "    [np.mean(psx[i][s[i]]) for i in label_errors_idx]\n",
    ")\n",
    "margin = self_confidence - psx[label_errors_bool].max(axis=1)\n",
    "label_errors_idx = label_errors_idx[np.argsort(margin)]\n",
    "\n",
    "print('Indices of label errors found by confident learning:')\n",
    "print('Note label errors are sorted by likelihood of being an error')\n",
    "print('but here we just sort them by index for comparison with above.')\n",
    "print(np.array(sorted(label_errors_idx)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "% actual errors that confident learning found: 95%\n",
      "% confident learning errors that are actual errors: 68%\n"
     ]
    }
   ],
   "source": [
    "score = sum([e in label_errors_idx for e in actual_label_errors]) / NUM_ERRORS\n",
    "print('% actual errors that confident learning found: {:.0%}'.format(score))\n",
    "score = sum([e in actual_label_errors for e in label_errors_idx]) / len(label_errors_idx)\n",
    "print('% confident learning errors that are actual errors: {:.0%}'.format(score))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
